#####################################
## Simulations: replication        ##
## NOTE: computationally intensive ##
## portions of the code have been  ##
## commented out, with correspon-  ##
## ding output cached, for speed.  ##
## All code is seeded, however, to ##
## allow for full replication.     ##
#####################################

## Load required packages and helper functions
library(NetMix)
library(tidyverse)
source("Code/NetSim2.R")
source("Code/GenSimData.R")

## Define common network characteristics
nnodes <- 200     # number of nodes, passed to NetSim2 as NODE
directed <- TRUE  # passed to NetSim2 as DIRECTED
n_sim <- 100      # number of simulation replicates

## Define three learning scenarios
## Blockmodels: 2 x 2 matrices filled column by column; probabilities are
## written on the [0, 1] scale and stored on the logit scale via qlogis().
BM_easy <- matrix(qlogis(c(0.85, 0.01, 0.01, 0.99)), nrow = 2)

BM_med <- matrix(qlogis(c(0.65, 0.2, 0.35, 0.75)), nrow = 2)

BM_hard <- matrix(qlogis(c(0.65, 0.5, 0.4, 0.45)), nrow = 2)

## Monadic coefficient arrays: each scenario is a list of two 2 x 2 arrays.
## Within an array the first column holds the intercepts and the second
## column holds the predictor coefficients.
beta_easy <- list(
  array(data = c(-4.5, -4.5, 0, 0), dim = c(2, 2)),
  array(data = c(-4.5, -4.5, 0, 0), dim = c(2, 2))
)
beta_med <- list(
  array(data = c(0.05, 0.75, -0.75, -1), dim = c(2, 2)),
  array(data = c(-0.05, 0.55, -0.75, 0.75), dim = c(2, 2))
)
beta_hard <- list(
  array(data = c(0, 0, -0.75, -1), dim = c(2, 2)),
  array(data = c(0, 0, -0.75, 0.75), dim = c(2, 2))
)

## Generate networks under three scenarios
## The three calls differ only in the blockmodel (B) and the monadic
## coefficients (beta_arr). The seed is reset to the same value before each
## call, so every scenario starts from the same RNG state.
set.seed(831213)
easy_net <- NetSim2(
  BLK = 2,
  NODE = nnodes,
  STATE = 2,
  TIME = 9,
  DIRECTED = directed,
  N_PRED = c(1, 1),
  B = BM_easy,
  beta_arr = beta_easy,
  sVec = rep(c(1, 2), times = c(5, 4)),
  gamma_vec = 0.1
)

set.seed(831213)
med_net <- NetSim2(
  BLK = 2,
  NODE = nnodes,
  STATE = 2,
  TIME = 9,
  DIRECTED = directed,
  N_PRED = c(1, 1),
  B = BM_med,
  beta_arr = beta_med,
  sVec = rep(c(1, 2), times = c(5, 4)),
  gamma_vec = 0.1
)

set.seed(831213)
hard_net <- NetSim2(
  BLK = 2,
  NODE = nnodes,
  STATE = 2,
  TIME = 9,
  DIRECTED = directed,
  N_PRED = c(1, 1),
  B = BM_hard,
  beta_arr = beta_hard,
  sVec = rep(c(1, 2), times = c(5, 4)),
  gamma_vec = 0.1
)
## Estimate models under three scenarios (cached)
#dat_easy <- GenSimData(n_node=nnodes, SynthNets=list(easy_net), mu_b=c(5, -5), alpha=0.5, lda_init=TRUE, type="Easy", tol=5e-3)[[1]]
#dat_med <- GenSimData(n_node=nnodes,SynthNets=list(med_net), mu_b=c(.5, -1.0), alpha=0.5, lda_init=TRUE, type="Realistic", tol=5e-3)[[1]]
#dat_med$pred_data$Pred <- 1-dat_med$pred_data$Pred #Posterior is rotation invariant (label switching); flip to match true labels.
#dat_hard <- GenSimData(n_node=nnodes,SynthNets=list(hard_net), mu_b=c(5.0, -5.0), alpha=0.5, lda_init=TRUE, type="Hard", tol=1e-3)[[1]]
#save(dat_easy, dat_med, dat_hard, file="Cache/SimScenarios.RData")
## Restore dat_easy, dat_med and dat_hard as written by the save() call above
load("Cache/SimScenarios.RData")

## Stack the per-scenario prediction tables, fix the scenario order to
## Easy / Realistic / Hard, and relabel the "Realistic" level as "Medium".
pred_data <- rbind(dat_easy$pred_data, dat_med$pred_data, dat_hard$pred_data) %>%
  mutate(
    Network = recode(
      factor(Network, levels = c("Easy", "Realistic", "Hard")),
      Realistic = "Medium"
    )
  )

## Blockmodels
## Long-format table with one row per blockmodel cell and scenario: the fitted
## edge probability (Prob), the sender/receiver group of the cell, the scenario
## label (Network) and the simulated edge probability (Truth).
## NOTE(review): the sender/receiver labels and the `each = 4` below assume
## every BlockModel is 2 x 2 -- confirm against GenSimData.
bmodels <- local({
  scenario_names <- c("Easy", "Medium", "Hard")

  per_scenario <- lapply(list(dat_easy, dat_med, dat_hard), function(fit) {
    # Fitted blockmodel is transformed to probabilities with plogis()
    est_prob <- plogis(fit$BlockModel)
    # c() unrolls the matrix column by column: rows are senders, columns are
    # receivers, so the sender index varies fastest.
    data.frame(
      Prob = c(est_prob),
      GroupSender = factor(rep(c("1", "2"), times = 2), levels = c("2", "1")),
      GroupReceiver = factor(rep(c("1", "2"), each = 2), levels = c("1", "2"))
    )
  })

  stacked <- do.call(rbind, per_scenario)
  stacked$Network <- factor(rep(scenario_names, each = 4), levels = scenario_names)
  stacked$Truth <- c(plogis(BM_easy), plogis(BM_med), plogis(BM_hard))
  stacked
})
# 

## Figure SI1
## Density of the `Truth` column for the Group 1 rows of pred_data, with one
## panel per scenario.
pred_data %>%
  filter(Group == 1) %>%
  ggplot(mapping = aes(x = Truth)) +
  geom_density(fill = "gray60") +
  facet_wrap(~Network) +
  labs(x = "Membership in Group 1", y = "") +
  theme_bw()

## Figure SI2: top row
## Scatter of estimated (Pred) against true (Truth) mixed-membership values,
## one panel per scenario, with both axes restricted to [0, 1].
ggplot(data = pred_data, mapping = aes(x = Truth, y = Pred)) +
  geom_point(size = 2.5, alpha = 0.7) +
  facet_grid(~Network) +
  xlim(0, 1) +
  ylim(0, 1) +
  labs(x = "True mixed-membership", y = "Estimate") +
  theme_bw()

## Figure SI2, bottom row
## Tile plot of the true edge probabilities by sender/receiver group, one
## panel per scenario; each tile is also labelled with its value.
bmodels %>%
  ggplot(aes(x = GroupReceiver, y = GroupSender, fill = Truth, label = Truth)) +
  geom_tile() +
  geom_text(colour = "white", fontface = "bold") +
  facet_grid(~Network) +
  scale_fill_gradient(low = "grey90", high = "gray10") +
  labs(
    x = "Receiver Group",
    y = "Sender Group",
    fill = "True Edge Probability"
  ) +
  theme_bw() +
  theme(legend.position = "bottom")

## Correlations between true and estimated MMs
## One correlation per scenario (level of Network)
pred_data %>%
  group_by(Network) %>%
  summarise(cor(Truth, Pred))

## Generate 100 synthetic networks (cached)
## The commented-out call is the expensive generation step; it is kept, with
## its seed, so the cache can be rebuilt.
#SynthNets <- GenSimData(n_node=nnodes, n_sim=n_sim, seed=123, B=BM_med, beta_arr=beta_med, nets_only=TRUE)
#save(SynthNets, file="Cache/SynthNets.RData")
## Restores `SynthNets` as written by the save() call above
load("Cache/SynthNets.RData")

## Simulation for effects (cached)
## Fits are obtained on the cached SynthNets with the medium-scenario
## blockmodel and coefficients (see the commented-out call).
#effect_sim <- GenSimData(n_node=nnodes, n_sim=n_sim, seed=123, B=BM_med, beta_arr=beta_med, hess = FALSE, SynthNets = SynthNets)
#save(effect_sim, file="Cache/EffectSims.RData")
## Restores `effect_sim` as written by the save() call above
load("Cache/EffectSims.RData")

## Reg. coefficient estimation accuracy
## For every simulated fit, the per-state estimates in MonadCoef[, 2, ] are
## expanded to the nine years through the state indicator matrix built from
## the sVec of the first synthetic network, and paired with the simulated
## value taken from beta_med (first state for 5 years, second for 4).

## Predictor coefficient: first index 2 of MonadCoef
fxs <- do.call(rbind, lapply(effect_sim, function(sim_fit) {
  slope_by_state <- sim_fit$MonadCoef[2, 2, ]
  state_by_year <- model.matrix(~ -1 + as.factor(SynthNets[[1]]$sVec))
  yearly_slope <- slope_by_state %*% t(state_by_year)
  data.frame(
    Year = as.factor(1:9),
    Effect = c(yearly_slope),
    True = rep(c(beta_med[[1]][2, 2], beta_med[[2]][2, 2]), c(5, 4)),
    Type = "Predictor Coefficient"
  )
}))

## Intercept: first index 1 of MonadCoef
intcpt <- do.call(rbind, lapply(effect_sim, function(sim_fit) {
  intercept_by_state <- sim_fit$MonadCoef[1, 2, ]
  state_by_year <- model.matrix(~ -1 + as.factor(SynthNets[[1]]$sVec))
  yearly_intercept <- intercept_by_state %*% t(state_by_year)
  data.frame(
    Year = as.factor(1:9),
    Effect = c(yearly_intercept),
    True = rep(c(beta_med[[1]][2, 1], beta_med[[2]][2, 1]), c(5, 4)),
    Type = "Intercept"
  )
}))

## Stack both tables; the factor levels fix the order of Type
all_fx <- rbind(fxs, intcpt)
all_fx$Type <- factor(
  all_fx$Type,
  levels = c("Predictor Coefficient", "Intercept")
)

## Figure SI3: Distribution of estimated effect sizes, and true effect size
## Boxplots show the spread of the estimates in all_fx for each year (outlier
## points are not drawn); the dark-red crosses mark the `True` column.
ggplot(all_fx, aes(x = Year, y = Effect)) +
  ## `scales` is written out in full: the abbreviated `scale` was only
  ## accepted through R's partial argument matching.
  facet_wrap(~Type, scales = "free_y") +
  geom_boxplot(outlier.shape = NA) +
  geom_point(aes(y = True), col = "darkred", pch = 4, size = 2.5) +
  ylab("Parameter Values for Group 2") +
  theme_bw()



## Comparison between no-covariate and non-dynamic mmsbm (cached)
## The commented-out call refits the cached SynthNets with intercept-only
## dyadic (Y ~ 1) and monadic (~1) formulas.
#net3.model.sim.nc <- GenSimData(nnodes, n_sim, 123, B=BM_med, beta_arr=beta_med, hess = FALSE, f.dyad=Y ~ 1, f.monad=~1, SynthNets = SynthNets)
#save(net3.model.sim.nc, file="Cache/SynthNetResultsNC.RData")
## Restores `net3.model.sim.nc` as written by the save() call above
load("Cache/SynthNetResultsNC.RData")

## Define values for MMSBM
## n_prior is split between a_p and b_p in proportion to the inverse logit of
## (0.5, -1.0), so a_p + b_p equals n_prior element-wise.
n_prior <- 5
a_p <- n_prior * plogis(c(0.5, -1))
b_p <- n_prior - a_p
## Estimate MMSBM models (cached)
# cg_mmsbm_res <- lapply(SynthNets,
#                        function(net, a = a_p, b=b_p, alpha_p = 0.5){
#                          soc_mats <- lapply(seq.int(9),
#                                             function(x){
#                                               dat <- subset(net$dyad.data, year == x)
#                                               datn <- subset(net$monad.data, year == x)
#                                               mat <- matrix(NA, ncol=net$NODE, nrow=net$NODE)
#                                               mat[cbind(dat$node1,dat$node2)] <- dat$Y
#                                               rownames(mat) <- colnames(mat) <- datn$node
#                                               return(mat)
#                                             })
#                          dyads <- split(net$dyad.data[,c("node1","node2")], net$dyad.data[,"year"])
#                          edges <- split(net$dyad.data[,c("Y")], net$dyad.data[,"year"])
#                          temp_res <- vector("list", 9)
#                          for(i in 1:9){
#                            lda_beta_prior <- lapply(list(b,a),
#                                                     function(prior){
#                                                       mat <- matrix(prior[2], net$BLK, net$BLK)
#                                                       diag(mat) <- prior[1]
#                                                       return(mat)
#                                                     })
#                            set.seed(831213)
#                            ret <- lda::mmsb.collapsed.gibbs.sampler(network = soc_mats[[i]],
#                                                                     K = net$BLK,
#                                                                     num.iterations = 100L,
#                                                                     burnin = 50L,
#                                                                     alpha = alpha_p,
#                                                                     beta.prior = lda_beta_prior)
#                            MixedMembership <- prop.table(ret$document_expects, 2)
#                            colnames(MixedMembership) <- colnames(soc_mats[[i]])
#                            BlockModel <- ret$blocks.pos/(ret$blocks.pos + ret$blocks.neg)#
#                            temp_res[[i]] <- list(BlockModel = BlockModel,
#                                                  MixedMembership = MixedMembership)
#                          }
#                          block_models <- lapply(temp_res, function(x)x$BlockModel)
#                          target_ind <- which.max(sapply(soc_mats, ncol))
#                          perms_temp <- NetMix:::.findPerm(block_models, target_mat = block_models[[target_ind]], use_perms = TRUE)
#                          phis_temp <- lapply(temp_res, function(x)x$MixedMembership)
#                          phi_init_t <- do.call(cbind,mapply(function(phi,perm){perm %*% phi},
#                                                             phis_temp, perms_temp, SIMPLIFY = FALSE))
#                          loss.mat.phi<- phi_init_t %*% net$pi_vecs
#                          phi_ord <- clue::solve_LSAP(t(loss.mat.phi), TRUE)
#                          return(phi_init_t[phi_ord, ])
#                        })
#save(cg_mmsbm_res, file="Cache/SynthNetResultsGibbs.RData")
## Restores `cg_mmsbm_res` as written by the save() call above
load("Cache/SynthNetResultsGibbs.RData")

## Per-simulation, per-period mean L2 error of the two sets of estimates
## against SynthNets[[i]]$pi_vecs.
## NOTE(review): averaging by `period` assumes the columns are ordered period
## by period with nnodes columns in each -- confirm against NetSim2.
error_df <- do.call(rbind, lapply(seq.int(n_sim), function(sim_id) {
  # Euclidean distance between matching columns of two tables
  l2_by_column <- function(reference, estimate) {
    mapply(function(p, q) {
      sqrt(sum((p - q)^2))
    }, reference, estimate)
  }

  reference_mm <- as.data.frame(t(SynthNets[[sim_id]]$pi_vecs))
  gibbs_l2 <- l2_by_column(
    reference_mm,
    as.data.frame((cg_mmsbm_res[[sim_id]]))
  )
  netmix_l2 <- l2_by_column(
    reference_mm,
    as.data.frame(net3.model.sim.nc[[sim_id]]$MixedMembership)
  )

  # Average the column-level errors within each of the nine periods
  period <- rep(1:9, each = nnodes)
  gibbs_mean <- tapply(gibbs_l2, list(period), mean)
  netmix_mean <- tapply(netmix_l2, list(period), mean)

  data.frame(
    Error = c(gibbs_mean, netmix_mean),
    Model = rep(c("MMSBM", "DynMMSBM"), each = 9),
    Time = factor(1:9),
    sim = sim_id
  )
}))

## Figure SI4
## Boxplots of the per-period mean L2 errors in error_df, one box per model
error_df %>%
  ggplot(mapping = aes(x = Model, y = Error)) +
  geom_boxplot() +
  labs(y = "Mean L2 Error") +
  theme_bw()





